{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "deb93f66",
   "metadata": {},
   "source": [
    "# Pytorch Lightning\n",
    "\n",
    "PyTorch has all you need to train your models; however, there’s much more to deep learning than attaching layers. When it comes to the actual training, there’s a lot of boilerplate code that one needs  to write. We have already seen this in our previous examples. This includes things like transferring data from CPU to GPU, implementing the training driver, etc.  Additionally, if one needs to scale training/inferencing on multiple devices/machines, there’s another set of integrations that often need to be done.\n",
    "\n",
    "PyTorch Lightning is a solution that provides the APIs required to build models, datasets, and so on. The idea is that Lightning leaves the research logic to you while automating the rest of the boilerplate code. Additionally, features like multi-GPU training, FP16, training on TPU are brought in inherently by Lightning without requiring any code changes.\n",
    "\n",
    "More details about PyTorch Lightning can be found at https://www.pytorchlightning.ai/tutorials\n",
    "\n",
    "We adapt PyTorch Lightning for the rest of the chapters whenever we are writing code to train models. This helps us make the code more succinct and precise. \n",
    "\n",
    "In this context, let us revisit the problem of digit classification and show how it can be implemented within the Lightning framework. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ddc81a20",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchmetrics\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "from pytorch_lightning import LightningDataModule, LightningModule, Trainer\n",
    "from torch.utils.data import DataLoader, random_split\n",
    "from torchvision import datasets, transforms\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "72793998",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 42\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "42"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pl.seed_everything(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f33389d",
   "metadata": {},
   "source": [
    "## DataModule\n",
    "\n",
    "A datamodule is a shareable, reusable class that encapsulates all the steps needed to process data. All datamodules must inherit from LightningDataModule which provides methods to be overriden. \n",
    "\n",
    "In this specific case, we will implement MNIST as a datamodule. This datamodule can now be used across multiple experiments spanning different models, architectures."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2ecb0531",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MNISTDataModule(LightningDataModule):\n",
    "    DATASET_DIR = \"datasets\"\n",
    "    \n",
    "    def __init__(self, transform=None, batch_size=100):\n",
    "        super(MNISTDataModule, self).__init__()\n",
    "        if transform is None:\n",
    "            # Default transform\n",
    "            transform = transforms.Compose([transforms.Resize((32, 32)),\n",
    "                                 transforms.ToTensor()])\n",
    "        self.transform = transform\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "    \n",
    "    def prepare_data(self):\n",
    "        \"\"\"\n",
    "        All the steps needed to download, tokenize, prepare the raw data should be done under\n",
    "        prepare data. We will download the MNIST dataset here.\n",
    "        \"\"\"\n",
    "        # Download the train data\n",
    "        datasets.MNIST(root = MNISTDataModule.DATASET_DIR, train = True, download = True)\n",
    "               \n",
    "        # Download the test data\n",
    "        datasets.MNIST(root = MNISTDataModule.DATASET_DIR, train = False, download = True)\n",
    "    \n",
    "    def setup(self, stage=None):\n",
    "        \"\"\"\n",
    "        The steps to setup the dataset are usually done under setup method. \n",
    "        \"\"\"\n",
    "        train_dataset = datasets.MNIST(root = MNISTDataModule.DATASET_DIR, train = True, \n",
    "                                            download = False, transform=self.transform)\n",
    "        # We will split the train dataset into train and validation sets.\n",
    "        # All experiments are run using the train and val datasets\n",
    "        self.train_dataset, self.val_dataset = random_split(train_dataset, [55000, 5000])\n",
    "        self.test_dataset = datasets.MNIST(root = MNISTDataModule.DATASET_DIR, train = False, \n",
    "                                            download = False, transform=self.transform)\n",
    "    \n",
    "    \n",
    "    def train_dataloader(self):\n",
    "        \"\"\"\n",
    "        As evident by the name, this method is responsible for creating and returning the \n",
    "        train dataloader\n",
    "        \"\"\"\n",
    "        return DataLoader(self.train_dataset, batch_size=self.batch_size, \n",
    "                          shuffle=True, num_workers=0) \n",
    "    \n",
    "    def val_dataloader(self):\n",
    "        \"\"\"\n",
    "        As evident by the name, this method is responsible for creating and returning the \n",
    "        val dataloader\n",
    "        \"\"\"\n",
    "        return DataLoader(self.val_dataset, batch_size=self.batch_size, \n",
    "                          shuffle=False, num_workers=0) \n",
    "    \n",
    "    def test_dataloader(self):\n",
    "        \"\"\"\n",
    "        As evident by the name, this method is responsible for creating and returning the \n",
    "        val dataloader\n",
    "        \"\"\"\n",
    "        return DataLoader(self.test_dataset, batch_size=self.batch_size, \n",
    "                                          shuffle=False, num_workers=0)\n",
    "    \n",
    "    @property\n",
    "    def num_classes(self):\n",
    "        return 10"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23887465",
   "metadata": {},
   "source": [
    "## LightningModule\n",
    "\n",
    "A LightningModule organizes your PyTorch code into 5 sections\n",
    "\n",
    "1. Computations (init).\n",
    "2. Train loop (training_step)\n",
    "3. Validation loop (validation_step)\n",
    "4. Test loop (test_step)\n",
    "5. Optimizers (configure_optimizers)\n",
    "\n",
    "Let us now see how we can define the LeNet classifier as a LightningModule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "32e74f06",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LeNetClassifier(LightningModule):\n",
    "    def __init__(self, num_classes):\n",
    "        \"\"\"\n",
    "        In __init__ we typically define the model, the criterion and any other setup steps needed to be done\n",
    "        for the training of the model.\n",
    "        \"\"\"        \n",
    "        super(LeNetClassifier, self).__init__()\n",
    "        self.save_hyperparameters()\n",
    "        \n",
    "        self.conv1 = torch.nn.Sequential(\n",
    "                        torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1),\n",
    "                        torch.nn.Tanh(),\n",
    "                        torch.nn.AvgPool2d(kernel_size=2))\n",
    "        self.conv2 = torch.nn.Sequential(\n",
    "                        torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),\n",
    "                        torch.nn.Tanh(),\n",
    "                        torch.nn.AvgPool2d(kernel_size=2))\n",
    "        self.conv3 = torch.nn.Sequential(\n",
    "                        torch.nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1),\n",
    "                        torch.nn.Tanh())\n",
    "        self.fc1 = torch.nn.Sequential(\n",
    "                        torch.nn.Linear(in_features=120, out_features=84),\n",
    "                        torch.nn.Tanh())\n",
    "        self.fc2 = torch.nn.Linear(in_features=84, out_features=num_classes)\n",
    "        \n",
    "        # We will use Cross Entropy Loss\n",
    "        self.criterion = torch.nn.CrossEntropyLoss()\n",
    "        \n",
    "        self.accuracy = torchmetrics.Accuracy()\n",
    "\n",
    "        \n",
    "    def forward(self, X):\n",
    "        \"\"\"\n",
    "        The forward method implements the forward pass of the model. In this case\n",
    "        the input is a batch of images, and the output is the logits\n",
    "        \"\"\"\n",
    "        # X is [batch_size, C, H, W] tensor\n",
    "        conv_out = self.conv3(self.conv2(self.conv1(X)))\n",
    "        batch_size = conv_out.shape[0]\n",
    "        conv_out = conv_out.reshape(batch_size, -1)\n",
    "        # Logits is [batch_size, num_classes] tensor\n",
    "        logits = self.fc2(self.fc1(conv_out))\n",
    "        return logits  \n",
    "    \n",
    "    def predict(self, X):\n",
    "        \"\"\"\n",
    "        Predict runs the forward pass, performs softmax to convert the resulting logits into\n",
    "        probabilities and returns the class with the highest probability.\n",
    "        \"\"\"\n",
    "        logits = self.forward(X)\n",
    "        probs = torch.softmax(logits, dim=1)\n",
    "        return torch.argmax(probs, 1)\n",
    "        \n",
    "        \n",
    "    def core_step(self, batch):\n",
    "        \"\"\"\n",
    "        Both the training and test loops involve the forward pass, computation of loss\n",
    "        and accuracy. Let us abstract it out and implement it under this method.\n",
    "        \"\"\"\n",
    "        X, y_true = batch\n",
    "        y_pred_logits = self.forward(X)\n",
    "        loss = self.criterion(y_pred_logits, y_true)\n",
    "        accuracy = self.accuracy(y_pred_logits, y_true)\n",
    "        return loss, accuracy\n",
    "        \n",
    "    \n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"\n",
    "        This method implements the basic training step. We will run forward pass, compute \n",
    "        loss, accuracy. We will log any necessary values, and return the total loss.\n",
    "        \"\"\"\n",
    "        loss, accuracy = self.core_step(batch)\n",
    "        if self.global_step % 100 == 0:\n",
    "            self.log(\"train_loss\", loss, on_step=True, on_epoch=True)\n",
    "            self.log(\"train_accuracy\", accuracy, on_step=True, on_epoch=True)\n",
    "        return loss\n",
    "    \n",
    "    def validation_step(self, batch, batch_idx, dataset_idx=None):\n",
    "        \"\"\"\n",
    "        This method implements the basic validation step. We will run the forward pass, compute the loss\n",
    "        and accuracy and return it.\n",
    "        \"\"\"\n",
    "        return self.core_step(batch)\n",
    "    \n",
    "    def validation_epoch_end(self, outputs):\n",
    "        \"\"\"\n",
    "        This method will be called at the end of all test steps for each epoch i.e the validation epoch end.\n",
    "        The output of every single test_step is available to via outputs. \n",
    "        \n",
    "        Here we will compute the average test loss and accuracy by simply averaging across all test batches\n",
    "        \"\"\"\n",
    "        avg_loss = torch.tensor([x[0] for x in outputs]).mean()\n",
    "        avg_accuracy = torch.tensor([x[1] for x in outputs]).mean()\n",
    "        self.log(\"val_loss\", avg_loss)\n",
    "        self.log(\"val_accuracy\", avg_accuracy)\n",
    "        print(f\"Epoch {self.current_epoch}, Val loss: {avg_loss:0.2f}, Accuracy: {avg_accuracy:0.2f}\")\n",
    "        return avg_loss\n",
    "    \n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"\n",
    "        The optimizer will be configured in this method\n",
    "        \"\"\"\n",
    "        return torch.optim.SGD(model.parameters(), lr=0.01,\n",
    "                      momentum=0.9)\n",
    "    \n",
    "    def checkpoint_callback(self):\n",
    "        \"\"\"\n",
    "        This callback determines the logic for how we want to checkpoint / save the model\n",
    "        \"\"\"\n",
    "        # We will save the model with the best val accuracy.\n",
    "        return ModelCheckpoint(monitor=\"val_accuracy\", mode=\"max\", save_top_k=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd4c11eb",
   "metadata": {},
   "source": [
    "Notice how the model is independent of the data. This will allow us to potentially run the LeNetClassifier model on other data modules without any code change. \n",
    "\n",
    "Note that we are not doing the following steps\n",
    "1. Moving the data to device\n",
    "2. Calling loss.backward \n",
    "3. Calling optimizer.backward\n",
    "4. Setting model.train() / eval()\n",
    "5. Resetting the gradients\n",
    "6. Implementing the trainer loop\n",
    "\n",
    "All of these are taken care of by PyTorch Lightning, thus eliminating a lot of boiler plate code."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78e2de70",
   "metadata": {},
   "source": [
    "## Trainer\n",
    "\n",
    "Now we are ready to train our model. This can be done using the Trainer class.\n",
    "\n",
    "\n",
    "This abstraction achieves the following:\n",
    "\n",
    "1. You maintain control over all aspects via PyTorch code without an added abstraction.\n",
    "2. The trainer uses best practices embedded by contributors and users from top AI labs such as Facebook AI Research, NYU, MIT, Stanford, etc…\n",
    "3. The trainer allows overriding any key part that you don’t want automated.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "64540e12",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n"
     ]
    }
   ],
   "source": [
    "dm = MNISTDataModule()\n",
    "model = LeNetClassifier(num_classes=dm.num_classes)\n",
    "exp_dir = \"/tmp/mnist\"\n",
    "trainer = Trainer(\n",
    "        default_root_dir=exp_dir, # The experiment directory\n",
    "        callbacks=[model.checkpoint_callback()],\n",
    "        gpus=torch.cuda.device_count(), # Number of GPUs to run on\n",
    "        max_epochs=10,\n",
    "        num_sanity_val_steps=0\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3165f04c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  | Name      | Type             | Params\n",
      "-----------------------------------------------\n",
      "0 | conv1     | Sequential       | 156   \n",
      "1 | conv2     | Sequential       | 2.4 K \n",
      "2 | conv3     | Sequential       | 48.1 K\n",
      "3 | fc1       | Sequential       | 10.2 K\n",
      "4 | fc2       | Linear           | 850   \n",
      "5 | criterion | CrossEntropyLoss | 0     \n",
      "6 | accuracy  | Accuracy         | 0     \n",
      "-----------------------------------------------\n",
      "61.7 K    Trainable params\n",
      "0         Non-trainable params\n",
      "61.7 K    Total params\n",
      "0.247     Total estimated model params size (MB)\n",
      "/Users/sujaynarumanchi/.virtualenvs/book_test_pyt113/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:132: UserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
      "  rank_zero_warn(\n",
      "/Users/sujaynarumanchi/.virtualenvs/book_test_pyt113/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:132: UserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
      "  rank_zero_warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fa7ef526809d4d79877cceefb147520b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, Val loss: 0.24, Accuracy: 0.93\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Val loss: 0.14, Accuracy: 0.96\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2, Val loss: 0.10, Accuracy: 0.97\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3, Val loss: 0.08, Accuracy: 0.98\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4, Val loss: 0.07, Accuracy: 0.98\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5, Val loss: 0.06, Accuracy: 0.98\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6, Val loss: 0.06, Accuracy: 0.98\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7, Val loss: 0.06, Accuracy: 0.98\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8, Val loss: 0.05, Accuracy: 0.98\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9, Val loss: 0.05, Accuracy: 0.99\n"
     ]
    }
   ],
   "source": [
    "trainer.fit(model, dm)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d52f8546",
   "metadata": {},
   "source": [
    "Note that we did not to write the trainer loop either. We just need to call trainer.fit to train the model. \n",
    "\n",
    "Additionally, the logging automatically enables us to look at the loss and accuracy curves via TensorBoard. This can be done by running `tensorboad --logdir /tmp/mnist` "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "71a06ffb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference follows a similar pattern as before. \n",
    "\n",
    "model.eval()\n",
    "X, y_true = next(iter(dm.test_dataloader()))\n",
    "with torch.no_grad():\n",
    "    y_pred = model.predict(X)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
